import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import cm
# 12-level discrete "Blues" colormap. matplotlib.cm.get_cmap was deprecated in
# matplotlib 3.7 and removed in 3.9, so use the colormap registry together with
# Colormap.resampled (matplotlib >= 3.6) and fall back to the legacy call only
# on older versions that lack them.
try:
    Blues = mpl.colormaps['Blues'].resampled(12)
except AttributeError:
    Blues = cm.get_cmap('Blues', 12)
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
import timeit
import math
import scipy
import scipy.linalg
import numpy.linalg
import scipy.special
import scipy.sparse
import scipy.sparse.linalg as ssla
import scipy.optimize
from mpmath import mp
from datetime import date
today = date.today()
# NOTE: this rebinds `date` from the datetime.date class to a string such as
# "Jan012024"; it is appended to output file names through paramString.
date = today.strftime("%b%d%Y")

# NOTE(review): the wildcard import places many sympy names (diff, cos,
# symbols, factorial, N, Lambda, ...) in the module namespace. Several of them
# are rebound later in this file (the numba `factorial`, `N`, `Lambda`), so the
# order of definitions matters.
from sympy import *
from sympy.vector import CoordSys3D

# Physical constants used to convert to oscillator units.
hbar= 6.582119569509e-13 #meV*second
electronMass= 5.68563006e-27 # meV * (second/nm)^2
eSquaredOvere = 1439.964547  # meV * nm (presumably e^2/(4*pi*eps0))

# Plot styling: TrueType (Type 42) fonts embedded in PDFs, small square figures.
mpl.rcParams['pdf.fonttype'] = 42
plt.rcParams.update({'font.size': 12})
plt.rcParams["figure.figsize"] = (2.00,2.00)
mpl.rcParams['font.family'] = 'Arial'
np.set_printoptions(precision=10,suppress=True,linewidth=1000)
from numba import njit
import numba

# Output locations for saved figures and numpy arrays (machine-specific paths).
figDir='/home/areddy/plots/wignermolecule/'
dataDir='/home/areddy/data/wignermoleculedata/'

######

"""
WORKHORSE FUNCTIONS
"""

def GenDiagEnergy(ns,ls):
    """Return the summed oscillator energies 2n + |l| + 1 along axis 1.

    ``ns`` and ``ls`` are arrays of radial and angular quantum numbers of the
    same shape; the per-orbital energies are added up along the second axis.
    """
    per_orbital = 2 * ns + np.abs(ls) + 1
    return np.sum(per_orbital, axis=1)

@njit
def factorial(n): #for numba compatibility
    """Return n! as a float (1.0 for n <= 1); numba-compilable factorial.

    The factors are multiplied in ascending order 2, 3, ..., n.
    """
    result = 1.0
    k = 2
    while k <= n:
        result *= k
        k += 1
    return result

#@njit
def fockDarwin(n,l,rhoThetaVals):
    """Evaluate the Fock-Darwin orbital (n, l) at a list of polar points.

    Parameters
    ----------
    n, l : int
        Radial quantum number (n >= 0) and angular momentum.
    rhoThetaVals : numpy.ndarray, shape (npts, 2)
        Column 0 is rho (its absolute value is used), column 1 is theta.

    Returns
    -------
    numpy.ndarray of complex, shape (npts,)
        N * rho^|l| * exp(-rho^2/2) * L_n^{|l|}(rho^2) * (-1)^n * exp(i l theta),
        with N = sqrt(n! / (pi (n+|l|)!)).
    """
    absL = abs(l)
    rho = abs(rhoThetaVals[:, 0])
    theta = rhoThetaVals[:, 1]
    norm = np.sqrt(math.factorial(n) / (np.pi * math.factorial(n + absL)))
    laguerre = scipy.special.genlaguerre(n, absL)
    # NOTE: the extra factor (-1)**n makes this definition differ from the
    # conventional one in the literature.
    radial = rho**absL * np.exp(-rho**2 / 2) * laguerre(rho**2) * (-1)**n
    angular = np.exp(+1j * l * theta)
    return norm * (radial * angular)
    
#@njit   
def natOrbWF(rhoThetaVals,natOrb,spBasis,cutoff=1e-10):
    """Evaluate an orbital given by its expansion in the Fock-Darwin basis.

    Parameters
    ----------
    rhoThetaVals : numpy.ndarray, shape (npts, 2)
        Polar points (rho, theta) passed to ``fockDarwin``.
    natOrb : array_like
        Expansion coefficients, one per row of ``spBasis``.
    spBasis : numpy.ndarray, shape (nSp, 2)
        (n, l) labels of the basis orbitals.
    cutoff : float, optional
        Coefficients with magnitude below this are skipped (default 1e-10).

    Returns
    -------
    numpy.ndarray of complex, shape (npts,)
        sum_i natOrb[i] * fockDarwin(n_i, l_i, rhoThetaVals). Always an array,
        even when every coefficient is below ``cutoff`` (previously a scalar 0).
    """
    nSp=spBasis.shape[0]
    psi=np.zeros(np.shape(rhoThetaVals)[0],dtype=complex)
    for i in range(nSp):
        if abs(natOrb[i])>=cutoff:
            n,l=spBasis[i]
            psi+=natOrb[i]*fockDarwin(n,l,rhoThetaVals)
    return(psi)


@njit
def CME(n1pi,n1mi,n2pi,n2mi,n1pj,n1mj,n2pj,n2mj): # coulomb matrix element
    """Two-body Coulomb matrix element in the (n+, n-) oscillator-quanta basis.

    Arguments ending in ``i`` describe the bra (particles 1 and 2) and those
    ending in ``j`` the ket; ``p``/``m`` label the n+ / n- quanta of each
    particle (AidanV converts (n, l) labels to these). Numba-compiled.

    Returns 0 when the total angular momentum sum(n+ - n-) of bra and ket
    differ. Otherwise it returns a fourfold finite sum over k1p, k1m, k2p, k2m
    of Gaussian radial integrals. That sum is multiplied by 2, a sign/phase
    factor, and the square root of the product of all eight factorials.

    NOTE: when dl == 0, deltaNtot = 2*((n1mi+n2mi) - (n1mj+n2mj)) is even, so
    (1j)**deltaNtot is +/-1 and the result is real-valued, although it is
    computed as a complex number.
    """
    nsum = n1pi+n1mi+n2pi+n2mi+n1pj+n1mj+n2pj+n2mj
    # change in total angular momentum between bra and ket
    dl = (n1pi+n2pi-n1mi-n2mi)-(n1pj+n2pj-n1mj-n2mj)
    deltaN2 = (n2pi+n2mi)-(n2pj+n2mj)
    deltaN1= (n1pi+n1mi)-(n1pj+n1mj)
    deltaNtot=deltaN2+deltaN1
    if dl!=0:
        return 0 #angular momentum conservation
    # Each prefactor a_k below follows a_k = (-1)^k / (k! (ni-k)! (nj-k)!),
    # built from a_0 = 1/(ni! nj!) via the ratio update at the end of each loop.
    S1p = 0 #sum over k1p
    a1p = 1/factorial(n1pi)/factorial(n1pj) #prefactor of each term in S1p
    for k1p in range(min(n1pi,n1pj)+1):
        S1m = 0 #sum over k1m
        a1m = 1/factorial(n1mi)/factorial(n1mj)
        for k1m in range(min(n1mi,n1mj)+1):
            S2p = 0
            a2p = 1/factorial(n2pi)/factorial(n2pj)
            for k2p in range(min(n2pi,n2pj)+1):
                S2m = 0
                a2m = 1/factorial(n2mi)/factorial(n2mj)
                for k2m in range(min(n2mi,n2mj)+1):
                    p = nsum-2*(k1p+k1m+k2p+k2m)
                    # I equals the integral of r^p * exp(-2 r^2) over r in [0, inf)
                    I = 2**(-(p+3)/2)*math.gamma((p+1)/2)
                    S2m += a2m*I
                    a2m *= -(n2mi-k2m)*(n2mj-k2m)/(k2m+1)
                S2p += a2p*S2m
                a2p *= -(n2pi-k2p)*(n2pj-k2p)/(k2p+1)
            S1m += a1m*S2p
            a1m *= -(n1mi-k1m)*(n1mj-k1m)/(k1m+1)
        S1p += a1p*S1m
        a1p *= -(n1pi-k1p)*(n1pj-k1p)/(k1p+1)
    # overall factor 2, sign/phase factors, and sqrt of all eight factorials
    Vij = S1p*2*(-1)**(abs(deltaN2))*(1j)**(deltaNtot)*((-1)**abs(n1pj+n1mj+n2pj+n2mj))*np.sqrt(factorial(n1pi)*factorial(n1pj)*factorial(n1mi)*factorial(n1mj)*factorial(n2pi)*factorial(n2pj)*factorial(n2mi)*factorial(n2mj))
    return Vij

def generateCMatrixFDBasis(spBasis):
    """Build the Coulomb tensor CMat[i, j, k, l] in the Fock-Darwin basis.

    Parameters
    ----------
    spBasis : numpy.ndarray, shape (nSp, 2)
        (n, l) labels of the single-particle orbitals.

    Returns
    -------
    numpy.ndarray of complex, shape (nSp, nSp, nSp, nSp)
        CMat[i, j, k, l] = AidanV(n_i, l_i, n_j, l_j, n_k, l_k, n_l, l_l).

    AidanV/CME return 0 unless l_i + l_j == l_k + l_l (angular momentum
    conservation). For each (i, j, k), only the indices l with
    l_l = l_i + l_j - l_k are evaluated. The result equals the full quadruple
    loop, but most of the nSp^4 calls are skipped.
    """
    nSp=spBasis.shape[0]
    CMat=np.zeros((nSp,nSp,nSp,nSp),dtype='complex')
    # Group orbital indices by their angular momentum l.
    indices_by_l = {}
    for idx, (_, lval) in enumerate(spBasis):
        indices_by_l.setdefault(lval, []).append(idx)
    for i, (ni, li) in enumerate(spBasis):
        for j, (nj, lj) in enumerate(spBasis):
            for k, (nk, lk) in enumerate(spBasis):
                # only fourth orbitals that conserve total angular momentum
                for m in indices_by_l.get(li + lj - lk, ()):
                    nl, ll = spBasis[m]
                    CMat[i,j,k,m]+= AidanV(ni,li,nj,lj,nk,lk,nl,ll) #change here to decide between my or Trithep's CME function
    return(CMat)

def generateCMatrixAltBasis(AltBasis,CMatFDBasis):
    """Rotate a two-body tensor into a new single-particle basis.

    Column m of ``AltBasis`` (U) holds the FD-basis components of new orbital m:
    V'[m,n,o,p] = sum_{ijkl} conj(U[i,m]) conj(U[j,n]) U[k,o] U[l,p] V[i,j,k,l].
    """
    U = AltBasis
    Ubar = np.conj(U)
    return np.einsum('im,jn,ko,lp,ijkl->mnop', Ubar, Ubar, U, U, CMatFDBasis)

def _fd_to_pm(n, l):
    """Convert a Fock-Darwin label (n, l) to oscillator quanta (n+, n-)."""
    n_plus = n + (abs(l) + l) // 2
    return n_plus, n_plus - l

def AidanV(n1_prime,l1_prime,n2_prime,l2_prime,n1,l1,n2,l2):
    """Coulomb matrix element <1', 2'| V |1, 2> for Fock-Darwin (n, l) labels.

    Each label is converted to (n+, n-) = (n + max(l, 0), n + max(-l, 0)), and
    the bra (primed) and ket (unprimed) quanta are passed to CME in that order.
    """
    bra1 = _fd_to_pm(n1_prime, l1_prime)
    bra2 = _fd_to_pm(n2_prime, l2_prime)
    ket1 = _fd_to_pm(n1, l1)
    ket2 = _fd_to_pm(n2, l2)
    return CME(*bra1, *bra2, *ket1, *ket2)

def GenerateFDStates(N): # my version
    """Enumerate Fock-Darwin single-particle states (n, l) in shells 0..N.

    Each state comes from oscillator quanta (nP, nM) with nP + nM <= N, giving
    n = min(nP, nM) and l = nP - nM. States are sorted primarily by shell
    energy 2n + |l| and, within a shell, by increasing l.

    Parameters
    ----------
    N : int
        Highest shell to include.

    Returns
    -------
    numpy.ndarray of int, shape ((N+1)(N+2)/2, 2)
        Rows are (n, l). The array is empty, with shape (0, 2), when N < 0.
    """
    stateList = np.array(
        [(min(nP, nM), nP - nM) for nP in range(N + 1) for nM in range(N - nP + 1)],
        dtype=int,
    ).reshape(-1, 2)
    E0 = 2 * stateList[:, 0] + np.abs(stateList[:, 1])
    L = stateList[:, 1]
    # |0.1*L/N| <= 0.1 < 1, so this tie-break never reorders different shells.
    # max(N, 1) avoids 0/0 for N == 0, where only l = 0 exists.
    ind_sort = np.argsort(E0 + 0.1 * L / max(N, 1))
    return stateList[ind_sort]

def GenerateFockStates(OneBodyStates,Nup,Ndn,L,mod=1000):
    """List Slater determinants with a given total angular momentum mod ``mod``.

    Spin-up orbitals are labelled 0..nstates-1 and spin-down orbitals
    nstates..2*nstates-1, where nstates = len(OneBodyStates).

    Parameters
    ----------
    OneBodyStates : numpy.ndarray, shape (nstates, 2)
        (n, l) labels; column 1 supplies each orbital's l.
    Nup, Ndn : int
        Numbers of spin-up and spin-down particles.
    L : int
        Target total angular momentum.
    mod : int, optional
        States are kept when (l_up + l_dn) % mod == L % mod (default 1000).

    Returns
    -------
    numpy.ndarray of int, shape (nfock, Nup + Ndn)
        Sorted occupied indices (up first, then down). The array keeps this
        two-dimensional shape even when no state matches.
    """
    from itertools import combinations
    ls = OneBodyStates[:,1]
    nstates = len(ls)
    # The down-spin angular momentum does not depend on the up-spin choice,
    # so compute each down-spin configuration and its l sum only once.
    dn_sectors = [
        (DnStates, sum(ls[i - nstates] for i in DnStates))
        for DnStates in combinations(range(nstates, 2 * nstates), Ndn)
    ]
    target = L % mod
    states = []
    for UpStates in combinations(range(nstates), Nup):
        l_up = sum(ls[i] for i in UpStates)
        for DnStates, l_dn in dn_sectors:
            if (l_up + l_dn) % mod == target:
                states.append(list(UpStates) + list(DnStates))
    return np.array(states, dtype=int).reshape(len(states), Nup + Ndn)

def generateAltFockBasis(nSp,NU,ND):
    """List every Slater determinant of NU up and ND down particles in nSp orbitals.

    Up orbitals are 0..nSp-1 and down orbitals nSp..2*nSp-1. Rows are ordered
    with the up-spin choice varying slowest.
    """
    from itertools import combinations, product
    up_choices = combinations(range(nSp), NU)
    dn_choices = list(combinations(range(nSp, 2 * nSp), ND))
    return np.array([list(up) + list(dn) for up, dn in product(up_choices, dn_choices)])

def generateHspMatFD(fdBasis):
    """Diagonal matrix of oscillator energies 2n + |l| + 1 for each (n, l) row."""
    n = fdBasis[:, 0]
    l = fdBasis[:, 1]
    return np.diag(2 * n + np.abs(l) + 1)

# def generateH0(nSp,fockBasis,Hsp,oneParticleDiffDict):
#     nF=fockBasis.shape[0]
#     H0=np.zeros((nF,nF),dtype='complex')
#     for i, fi in enumerate(fockBasis):
#         for k in fi:
#             H0[i,i]+=Hsp[k%nSp,k%nSp]
#     for i in oneParticleDiffDict:
#         fi=fockBasis[i]
#         for j in oneParticleDiffDict[i]:
#             fj=fockBasis[j]
#             spa=set(fi).difference(set(fj)).pop() 
#             spb=set(fj).difference(set(fi)).pop()
#             if (spa>=nSp and spb>=nSp) or (spa<nSp and spb<nSp): #make sure same spin
#                 if spa < spb:
#                     betweenCount=np.count_nonzero((fj>spa)&(fj<spb))
#                 elif spb < spa:
#                     betweenCount=np.count_nonzero((fj>spb)&(fj<spa))
#                 sign=(-1)**(betweenCount)
#                 H0[i,j]+=Hsp[spa%nSp,spb%nSp]*sign
#     return H0

def generateH0(nSp,fockBasis,Hsp,oneParticleDiffDict):
    """Sparse one-body Hamiltonian in the Fock basis.

    Parameters
    ----------
    nSp : int
        Number of spatial orbitals; spin-orbital k has spatial index k % nSp
        and is spin down when k >= nSp.
    fockBasis : numpy.ndarray, shape (nF, nparticles)
        Occupied spin-orbitals of each Fock state.
    Hsp : array_like, shape (nSp, nSp)
        Single-particle Hamiltonian (spin independent).
    oneParticleDiffDict : dict
        Maps a Fock index to the indices differing from it by one particle.

    Returns
    -------
    scipy.sparse.csr_matrix of complex, shape (nF, nF)
        Diagonal: sum of Hsp[k, k] over occupied k. Off-diagonal [r, c]:
        Hsp[a, b] times the fermionic sign (-1)^(number of orbitals of state c
        strictly between a and b), for a same-spin move b -> a.
    """
    dim = fockBasis.shape[0]
    row_idx, col_idx, entries = [], [], []

    def push(r, c, v):
        row_idx.append(r)
        col_idx.append(c)
        entries.append(v)

    # Diagonal terms: each occupied spin-orbital contributes its own energy.
    for r, occupied in enumerate(fockBasis):
        for orb in occupied:
            push(r, r, Hsp[orb % nSp, orb % nSp])

    # Off-diagonal terms: Fock states related by moving a single particle.
    for r in oneParticleDiffDict:
        occ_r = fockBasis[r]
        for c in oneParticleDiffDict[r]:
            occ_c = fockBasis[c]
            spa = set(occ_r).difference(set(occ_c)).pop()
            spb = set(occ_c).difference(set(occ_r)).pop()
            if (spa >= nSp) != (spb >= nSp):
                continue  # one-body terms cannot flip spin
            lo, hi = min(spa, spb), max(spa, spb)
            n_between = np.count_nonzero((occ_c > lo) & (occ_c < hi))
            push(r, c, Hsp[spa % nSp, spb % nSp] * (-1) ** n_between)

    return scipy.sparse.csr_matrix((entries, (row_idx, col_idx)), shape=(dim, dim), dtype='complex')

@njit
def kd(a,b):
    """Kronecker delta: 1 when a equals b, otherwise 0 (numba-compiled)."""
    return 1 if a == b else 0

def calcap(fdBasis):
    """Matrix of the a_+ lowering operator in a Fock-Darwin basis.

    Each state (n, l) is mapped to oscillator quanta
    n_+ = n + (l + |l|)/2 and n_- = n + (|l| - l)/2. Element [i, j] equals
    sqrt(n_+ of j) when state i has one fewer n_+ quantum than state j and the
    same n_-, and 0 otherwise.

    The original element-by-element Python loop has been replaced by an
    equivalent numpy broadcast.

    Parameters
    ----------
    fdBasis : array_like, shape (nSp, 2)
        (n, l) labels.

    Returns
    -------
    numpy.ndarray of float, shape (nSp, nSp)
    """
    fdBasis = np.asarray(fdBasis)
    n = fdBasis[:, 0].astype(float)
    l = fdBasis[:, 1].astype(float)
    nP = n + (l + np.abs(l)) / 2
    nM = n + (-l + np.abs(l)) / 2
    connected = (nP[:, None] == nP[None, :] - 1) & (nM[:, None] == nM[None, :])
    return np.where(connected, np.sqrt(nP)[None, :], 0.0)

def dag(O):
    """Return the Hermitian conjugate (conjugate transpose) of ``O``."""
    conjugated = np.conj(O)
    return np.transpose(conjugated)

def calcam(fdBasis):
    """Matrix of the a_- lowering operator in a Fock-Darwin basis.

    Each state (n, l) is mapped to oscillator quanta
    n_+ = n + (l + |l|)/2 and n_- = n + (|l| - l)/2. Element [i, j] equals
    sqrt(n_- of j) when state i has one fewer n_- quantum than state j and the
    same n_+, and 0 otherwise.

    The original element-by-element Python loop has been replaced by an
    equivalent numpy broadcast.

    Parameters
    ----------
    fdBasis : array_like, shape (nSp, 2)
        (n, l) labels.

    Returns
    -------
    numpy.ndarray of float, shape (nSp, nSp)
    """
    fdBasis = np.asarray(fdBasis)
    n = fdBasis[:, 0].astype(float)
    l = fdBasis[:, 1].astype(float)
    nP = n + (l + np.abs(l)) / 2
    nM = n + (-l + np.abs(l)) / 2
    connected = (nP[:, None] == nP[None, :]) & (nM[:, None] == nM[None, :] - 1)
    return np.where(connected, np.sqrt(nM)[None, :], 0.0)


def generateHspAltBasis(AltBasis, HspFDBasis):
    """Transform a single-particle operator into a new basis: U^dagger H U.

    Column m of ``AltBasis`` (U) holds the FD-basis components of new orbital m.
    """
    Udag = np.conj(AltBasis).T
    return np.einsum('ij,jk,kl->il', Udag, HspFDBasis, AltBasis)

def canonical_interaction(num_1p,i,j,k,l):
    """Canonical representative of a two-body index (i, j, k, l).

    Indices are first reduced modulo ``num_1p``, which drops the spin offset.
    The result is the lexicographically smallest of (i,j,k,l), (j,i,l,k),
    (k,l,i,j) and (l,k,j,i).

    NOTE(review): treating (k,l,i,j) as equivalent to (i,j,k,l) relies on the
    matrix elements being real; for complex elements of a Hermitian
    interaction these two differ by complex conjugation.
    """
    a, b, c, d = (idx % num_1p for idx in (i, j, k, l))
    return min((a, b, c, d), (b, a, d, c), (c, d, a, b), (d, c, b, a))

def GenerateHint(FockStates,CMat):
    """Build the two-body interaction matrix in a Fock (Slater-determinant) basis.

    Parameters
    ----------
    FockStates : numpy.ndarray of int, shape (nstates, nparticles)
        Occupied spin-orbital indices of each basis state. Indices below
        ``CMat.shape[0]`` are spin up; the rest are spin down, offset by
        ``CMat.shape[0]``.
    CMat : numpy.ndarray, shape (nSp, nSp, nSp, nSp)
        Two-body matrix elements. They are read only at indices produced by
        ``canonical_interaction``.

    Returns
    -------
    Hint : scipy.sparse.csr_matrix of complex, shape (nstates, nstates)
        Sum of signed CMat elements; entry [j, i] couples state i to state j.
    oneParticleDiffDict : dict
        Maps each state index i to the array of state indices whose occupation
        differs from state i in exactly one spin-orbital moved.
    """
    from collections import defaultdict
    # combos maps a canonical (a, b, c, d) index to every (sign, (row, col))
    # that uses CMat[a, b, c, d], so each matrix element is looked up once.
    combos = defaultdict(list)
    nstates = len(FockStates)
    num_1pstates = CMat.shape[0]
    # Occupation vectors over all 2*num_1pstates spin-orbitals (sized from
    # CMat itself rather than from a module-level global).
    StateVectors= np.zeros([FockStates.shape[0],num_1pstates*2],dtype=int)
    oneParticleDiffDict=dict()
    for i in range(nstates):
        StateVectors[i][FockStates[i]]=1
    for i in range(nstates):
        FockState = np.array(FockStates[i])
        # Diagonal (i -> i): direct term for every occupied pair, plus the
        # exchange term when both particles have the same spin.
        for a in range(len(FockState)):
            for b in range(a+1,len(FockState)):
                alpha = FockState[a]
                beta = FockState[b]
                SameSpin = (beta<num_1pstates and alpha<num_1pstates) or (beta>=num_1pstates and alpha>=num_1pstates)
                if SameSpin:
                    direct = (alpha,beta,alpha,beta)
                    exchange = (beta,alpha,alpha,beta)
                    combos[canonical_interaction(num_1pstates,*direct)].append((1,(i,i)))
                    combos[canonical_interaction(num_1pstates,*exchange)].append((-1,(i,i)))
                else:
                    direct = (alpha,beta,alpha,beta)
                    combos[canonical_interaction(num_1pstates,*direct)].append((1,(i,i)))
        i_StateVector = StateVectors[i]
        delta = StateVectors[:,:] - i_StateVector[None,:]
        num_diffs = (np.sum(delta!=0,axis=1))
        # With fixed particle number, 2 differing slots means one particle was
        # moved and 4 means two; larger differences have no two-body element.
        one_particle_diff = np.where(num_diffs==2)[0]
        oneParticleDiffDict[i]=one_particle_diff
        two_particle_diff = np.where(num_diffs==4)[0]
        for j in one_particle_diff:
            j_StateVector = StateVectors[j]
            # alpha is occupied in i but not j; alpha_prime is occupied in j only.
            alpha = np.where(delta[j]==-1)[0][0]
            alpha_prime = np.where(delta[j]==1)[0][0]
            for beta in FockState:
                if beta == alpha: continue
                SameSpin = (beta<num_1pstates and alpha<num_1pstates) or (beta>=num_1pstates and alpha>=num_1pstates)
                # Fermionic sign from the occupied orbitals preceding the
                # annihilated (in i) and created (in j) orbitals.
                if SameSpin:
                    direct = (alpha_prime,beta,alpha,beta)
                    exchange = (beta,alpha_prime,alpha,beta)
                    sign = (-1)**(abs(np.sum(i_StateVector[:alpha])-np.sum(j_StateVector[:alpha_prime])))
                    combos[canonical_interaction(num_1pstates,*direct)].append((sign,(j,i)))
                    combos[canonical_interaction(num_1pstates,*exchange)].append((-sign,(j,i)))
                else:
                    direct = (alpha_prime,beta,alpha,beta)
                    sign = (-1)**(abs(np.sum(i_StateVector[:alpha])-np.sum(j_StateVector[:alpha_prime])))
                    combos[canonical_interaction(num_1pstates,*direct)].append((sign,(j,i)))
        for j in two_particle_diff:
            j_StateVector = StateVectors[j]
            # (alpha, beta) are removed from i, (alpha_prime, beta_prime) added in j.
            alpha,beta = np.where(delta[j]==-1)[0]
            alpha_prime,beta_prime = np.where(delta[j]==1)[0]
            SameSpin = (alpha < num_1pstates and beta < num_1pstates) or (alpha >= num_1pstates and beta >= num_1pstates)
            if SameSpin:
                direct = (alpha_prime,beta_prime,alpha,beta)
                exchange = (beta_prime,alpha_prime,alpha,beta)
                sign = (-1)**(abs(np.sum(i_StateVector[:alpha])-np.sum(j_StateVector[:alpha_prime])
                                 +np.sum(i_StateVector[:beta])-np.sum(j_StateVector[:beta_prime])))
                combos[canonical_interaction(num_1pstates,*direct)].append((sign,(j,i)))
                combos[canonical_interaction(num_1pstates,*exchange)].append((-sign,(j,i)))
            else:
                direct = (alpha_prime,beta_prime,alpha,beta)
                sign = (-1)**(abs(np.sum(i_StateVector[:alpha])-np.sum(j_StateVector[:alpha_prime])
                                 +np.sum(i_StateVector[:beta])-np.sum(j_StateVector[:beta_prime])))
                combos[canonical_interaction(num_1pstates,*direct)].append((sign,(j,i)))
    rows=[]
    cols=[]
    vals=[]
    for stateindices,sign_jis in combos.items():
        a,b,c,d = stateindices
        v=CMat[a,b,c,d]
        for sign,(j,i) in sign_jis:
            rows.append(j)
            cols.append(i)
            vals.append(sign*v)
    # Duplicate (row, col) pairs are summed by the sparse constructor.
    Hint = scipy.sparse.csr_matrix((vals, (rows, cols)), shape=(nstates,nstates),dtype='complex')
    return Hint, oneParticleDiffDict
    

def oneBodyDMat(eState,oneParticleDiffDict,spBasis,fBasis):
    """Spin-resolved one-body density matrix of a many-body state.

    Parameters
    ----------
    eState : array_like
        Amplitudes in the ``fBasis`` ordering.
    oneParticleDiffDict : dict
        Maps a Fock index to the indices differing from it by one particle.
    spBasis : numpy.ndarray, shape (nSp, ...)
        Single-particle basis; only its length is used.
    fBasis : numpy.ndarray
        Occupied spin-orbitals of each Fock state (spin down when >= nSp).

    Returns
    -------
    numpy.ndarray of complex, shape (2, nSp, nSp)
        D[s, a, b] accumulates conj(eState[j]) * eState[i] * sign for each pair
        of Fock states i, j where orbital a is occupied in i and b in j
        (diagonal: |eState[i]|^2 per occupied orbital).
    """
    nSp = spBasis.shape[0]
    D = np.zeros((2, nSp, nSp), dtype='complex')

    # Diagonal Fock-space contributions: occupation weighted by |amplitude|^2.
    for i, occupied in enumerate(fBasis):
        weight = np.conj(eState[i]) * eState[i]
        for orb in occupied:
            spin, orbital = divmod(orb, nSp)
            D[abs(spin), orbital, orbital] += weight

    # Off-diagonal contributions from states related by one same-spin move.
    for i, partners in oneParticleDiffDict.items():
        occ_i = fBasis[i]
        for j in partners:
            occ_j = fBasis[j]
            spa = set(occ_i).difference(set(occ_j)).pop()
            spb = set(occ_j).difference(set(occ_i)).pop()
            if (spa >= nSp) != (spb >= nSp):
                continue  # spin flip: no contribution
            lo, hi = min(spa, spb), max(spa, spb)
            sign = (-1) ** np.count_nonzero((occ_j > lo) & (occ_j < hi))
            D[abs(spa // nSp), spa % nSp, spb % nSp] += np.conj(eState[j]) * eState[i] * sign
    return D


######

"""
Strained moiré Wigner molecule
"""

# --- Model parameters -------------------------------------------------------
# phi: phase of the moire potential, in radians (20 degrees).
phi = 20.0 * np.pi/180
# V: moire potential amplitude, used as the prefactor of mPCart below.
V = 9.65/np.cos(phi) #meV
# eps: dielectric constant screening the Coulomb interaction (enters Lambda).
eps = 5 #5
# mStar: effective mass in units of electronMass.
mStar = 0.9
# aM: moire lattice constant.
aM = 9.8 #nm
L = 0 # orbital angular momentum of state (mod rotation symmetry breaking)

N = 9 #9 #maximum shell in single particle basis

nTaylor = 6 # go to order r^(nTaylor)

# Uniaxial strain: relative stretch and the angle of the stretch axis.
stretchfactor = 0.0
stretchangle = 0 #degrees

# Without strain the potential keeps 3-fold rotation symmetry, so total L is
# only conserved mod 3; generic strain breaks it completely (mod 1).
if stretchfactor == 0:
    mod = 3
else:
    mod = 1 #angular momentum not conserved for generic strain

# Numbers of spin-up and spin-down electrons.
NU = 3
ND = 0

print("NU, ND:",NU,ND)
print("N:", N)

# --- Strained triangular moire lattice ----------------------------------------
a1= aM*np.array([1.0, 0, 0]) #np.array([11.04, 0, 0]) #LONG AM REGION
a2= aM*np.array([1/2, np.sqrt(3)/2, 0]) #np.array([3.411, 7.883, 0])
theta=stretchangle*(np.pi/180)
# Stretch along x by (1 + s/2), compress along y by (1 - s/2), then rotate the
# stretch axis by theta: lintrans = R(theta) S R(-theta).
stretchMat=np.array([[1+stretchfactor/2,0,0],[0,1-stretchfactor/2,0],[0,0,0]])
rotMat = np.array([[np.cos(theta), -np.sin(theta), 0],[np.sin(theta), np.cos(theta), 0],[0,0,1]])
rotInv =  np.array([[np.cos(-theta), -np.sin(-theta), 0],[np.sin(-theta), np.cos(-theta), 0],[0,0,1]])
lintrans = rotMat @ stretchMat @ rotInv
a1 = lintrans @ a1
a2 = lintrans @ a2

ez = np.array([0,0,1])

# Unit-cell area (z component of a1 x a2), in nm^2.
A = np.cross(a1, a2)[2]

# Reciprocal vectors satisfying b_i . a_j = 2*pi*delta_ij.
b1 = 2*np.pi*(np.cross(a2, ez)) / A
b2 = 2*np.pi*(np.cross(ez, a1)) / A

(b1x, b1y) = b1[:2]

(b2x, b2y) = b2[:2]

print((b1x, b1y))
print((b2x, b2y))

# NOTE(review): for the unstrained lattice |b2| = 4*pi/(sqrt(3)*aM), so this
# prints 1/aM rather than 1; divide by 4*pi/(sqrt(3)*aM) if a ratio of 1 is meant.
print(np.sqrt(b2x**2 + b2y**2)/(4*np.pi / (np.sqrt(3))))

#aM = np.sqrt(A*2/np.sqrt(3)) #4*np.pi / (np.sqrt(3)) / ((np.sqrt(b1x**2 + b1y**2) + np.sqrt(b2x**2 + b2y**2)) / 2)

print("aM:", aM)

#useCFBasis==True

# Tag appended to every saved array / figure name (ends with today's date).
paramString='WigMolNU%dND%dN%deps%.2fV=%.2fmStar=%.2faM=%.2fphi=%.2fb1=(%.2f, %.2f)b2=(%.2f,%.2f)ntaylor%dstretchfactor%.2fangle%.2f%s'%(NU,ND,N,eps,V,mStar,aM,phi*(180/np.pi),b1x,b1y,b2x,b2y,nTaylor,stretchfactor,stretchangle,date)

# --- Single-particle basis ------------------------------------------------------
FDBasis=GenerateFDStates(N)

nSp=FDBasis.shape[0]
print('nSp:', nSp)

tic=timeit.default_timer()
FDBasisEnlarged=GenerateFDStates(N+2*nTaylor) #enlarge basis to allow for higher order matrix elements to be treated properly
# The first nSp enlarged states must coincide with FDBasis, since operator
# matrices built in the enlarged basis are later trimmed to [0:nSp, 0:nSp].
for i  in range(nSp):
    if np.array_equal(FDBasis[i],FDBasisEnlarged[i])==False:
        print('enlarged and truncated FB basis mismatch!')

#SET UP CRYSTAL FIELD

vec = CoordSys3D('vec')

#Defining operators (ladder, position and momentum matrices in the enlarged FD basis)
ap=calcap(FDBasisEnlarged)
am=calcam(FDBasisEnlarged)
# rp == x + i*y and rm == x - i*y, consistent with the definitions of x, y below.
rp=ap+dag(am)
rm=dag(ap)+am
pp=(ap-dag(am))/2j
pm=(am-dag(ap))/2j
x=((rp+rm)/2).astype('complex')
y=(rp-rm)/(2j)

# Kinetic energy matrix (presumably p^2/2 in units of hbaromega, with momentum
# in units of hbar/l); TODO confirm the normalisation of pp/pm.
HspKineticMatFD = (4*pm@pp)/2

B1x, B1y, B2x, B2y, Phi, X, Y = symbols('B1x B1y B2x B2y Phi X Y')

rCart = X*vec.i +  Y*vec.j
B1 = B1x*vec.i + B1y*vec.j
B2 = B2x*vec.i + B2y*vec.j
B3 = -(B1 + B2)

# Symbolic moire potential -2V * sum_j cos(B_j . r + Phi) over the three
# reciprocal vectors B1, B2, B3 = -(B1 + B2).
mPCart = -2 * V * (cos(rCart.dot(B1) + Phi) + cos(rCart.dot(B2) + Phi) + cos(rCart.dot(B3) + Phi))

HspCFMatFD = np.zeros_like(x)

# average spring constant: (d2V/dX2 + d2V/dY2)/2 evaluated at the origin

kAvg = float((diff(mPCart, X, 2) + diff(mPCart, Y, 2)).subs([(X, 0), (Y, 0), (Phi, phi), (B1x, b1x), (B1y, b1y), (B2x, b2x), (B2y, b2y)])) / 2

print("kAvg", kAvg)

hbaromega = hbar * np.sqrt(kAvg/(mStar * electronMass)) #Hamiltonian will be in units of hbaromega

# oscillator length l = sqrt(hbar / (m* omega))
l = np.sqrt( hbar**2 / (mStar * electronMass * hbaromega) )

# Coulomb energy e^2/(eps*l) in units of hbaromega; scales Hint later.
Lambda = eSquaredOvere / (eps * l) / hbaromega

xic  = (Lambda/4)**(1/3) * l 

print("l, hbaromega, lambda, eps, xic, l/aM:", l, hbaromega, Lambda, eps, xic, (l/aM))

coeffArray = np.zeros((nTaylor+1,nTaylor+1))

# Taylor expansion of the potential about the origin up to total order nTaylor:
# coeff[n, m] = d^n_X d^m_Y V(0) * l^(n+m) / (n! m!) / hbaromega, and the
# crystal-field matrix is sum coeff[n, m] * x^n y^m in the enlarged basis.
# `factorial` here is the numba version defined in this file (it shadows the
# sympy name from the wildcard import).
for n in range(nTaylor+1):
    for m in range(nTaylor+1):
        if n+m <= nTaylor:
            coeff = diff(diff(mPCart, X, n), Y, m).subs([(X, 0), (Y, 0), (Phi, phi), (B1x, b1x), (B1y, b1y), (B2x, b2x), (B2y, b2y)]) * (l)**(n+m) / factorial(n) / factorial(m) / hbaromega
            coeffArray[n,m] = coeff
            #print(np.abs((numpy.linalg.matrix_power(x, n) @ numpy.linalg.matrix_power(y, m))))
            HspCFMatFD += np.real(complex(coeff)) * (numpy.linalg.matrix_power(x, n) @ numpy.linalg.matrix_power(y, m))

print("coeffArray[2,2]: ", coeffArray[2,2])

#HspCFMatFD += 100 * numpy.linalg.matrix_power(x, 2) #for checking consistent definition of X/Y
 
# print("diag(real(HspCFMatFD)): \n", np.diag(np.real(HspCFMatFD)))
# print("real(HspKineticMatFD): \n", np.real(HspKineticMatFD))

toc=timeit.default_timer()
print('crystal field time (s):', toc-tic)

HspMatFD = (HspKineticMatFD + HspCFMatFD)[0:nSp,0:nSp] #trim excess

# print("real(HspMatFD): \n", np.real(HspMatFD))

# Single-particle spectrum; eigenvalues rounded to 12 decimals to count degeneracies.
eValsSP=np.linalg.eigvalsh(HspMatFD)
uniqueEVals, Degens = np.unique(np.around(eValsSP,12),return_counts=True)
summary=np.stack((uniqueEVals, Degens), axis=1)
print('single particle eVals and degeneracy:\n', summary[0:8])

# --- Many-body Hamiltonian in the Fock basis --------------------------------------
fockBasisFD = GenerateFockStates(FDBasis,NU,ND,L,mod)
print('dim(H):', fockBasisFD.shape[0])
tic=timeit.default_timer()
CMatFD = generateCMatrixFDBasis(FDBasis)

toc=timeit.default_timer()
print('Coul mat elts time (s):', toc-tic)
tic=timeit.default_timer()
Hint, oneParticleDiffDictFD = GenerateHint(fockBasisFD,CMatFD)
# Scale the Coulomb term by Lambda = e^2/(eps*l)/hbaromega so that the whole
# Hamiltonian is in units of hbaromega.
Hint *= Lambda

toc=timeit.default_timer()
print('Hint time (s):', toc-tic)
tic=timeit.default_timer()
H0=generateH0(nSp,fockBasisFD,HspMatFD,oneParticleDiffDictFD)
toc=timeit.default_timer()
print('H0 time (s):', toc-tic)

print()
#print('H0:\n', H0)
Es=[]
gsVecs=[]

tic=timeit.default_timer()
H = Hint  + H0
#tic=timeit.default_timer()
# Lowest ('SA' = smallest algebraic) eigenpair. scipy documents maxiter as an
# int; the former float 1e10 does not fit ARPACK's integer iteration counter,
# so use the largest 32-bit int as an effectively unlimited cap.
eVal, eVec = ssla.eigsh(H,k=1,which='SA',maxiter=np.iinfo(np.int32).max, return_eigenvectors=True)
Es.append(eVal[0])
gs=eVec[:,0]


GSEvecSaveName=("GSEvec"+paramString)
np.save(dataDir+GSEvecSaveName, gs)

# Natural-orbital occupations: eigenvalues of each spin block of the one-body
# density matrix, rounded to 12 decimals to count degeneracies.
rho=oneBodyDMat(gs,oneParticleDiffDictFD,FDBasis,fockBasisFD)
natEVals, natEVecs=np.linalg.eigh(rho)
#print('natEVals:\n', natEVals)
uniqueNotOEVals, NatODegens = np.unique(np.around(natEVals,12),return_counts=True)
summary=np.flip(np.stack((uniqueNotOEVals, NatODegens), axis=1),axis=0)
#print('natEVals and degeneracy:\n', summary[0:5])
gsVecs.append(gs)
toc=timeit.default_timer()
print('diag time:\n', toc-tic)


######


"""
PLOT CHARGE DENSITY OF MANY BODY STATE
"""

def computenofr(eState,oneParticleDiffDict,fockBasis,FDBasis,xMax,numx):
    """Real-space charge density of a many-body state on a square grid.

    The density is accumulated from two kinds of terms. Diagonal terms add
    |eState[i] * phi_sp|^2 for every occupied orbital sp of Fock state i.
    Off-diagonal terms come from pairs of Fock states that differ by one
    same-spin move (``oneParticleDiffDict``), with the fermionic sign.
    Amplitudes with magnitude <= 1e-7 are skipped. Orbitals are evaluated
    with ``fockDarwin``; each distinct orbital is computed once and reused.

    Parameters
    ----------
    eState : array_like
        Amplitudes of the state in the ``fockBasis`` ordering.
    oneParticleDiffDict : dict
        Maps a Fock index i to the indices j differing from it by one particle.
    fockBasis : numpy.ndarray
        Occupied spin-orbitals per Fock state (spin down when >= nSp).
    FDBasis : numpy.ndarray, shape (nSp, 2)
        Single-particle (n, l) labels.
    xMax : float
        Half-width of the grid, in the length units used by fockDarwin.
    numx : int
        Number of grid points per axis.

    Returns
    -------
    numpy.ndarray, shape (numx, numx)
        Real part of the density; element [i, j] belongs to the grid values
        x = xVals[i], y = xVals[j].
    """
    cutoff = 1e-7
    xVals=np.linspace(-xMax,xMax,numx)
    nSp=np.shape(FDBasis)[0]
    nofr=np.zeros((numx**2),dtype=complex)
    # Flattened grid point j + i*numx holds (x, y) = (xVals[i], xVals[j]).
    xGrid, yGrid = np.meshgrid(xVals, xVals, indexing='ij')
    xFlat = xGrid.ravel()
    yFlat = yGrid.ravel()
    # theta satisfies sin(theta) = x/rho, cos(theta) = y/rho (measured from
    # the +y axis). This equals the former 2*arctan(x/(y+rho)) wherever that
    # expression was defined. Unlike it, arctan2 stays finite at x == 0, y <= 0
    # (including the origin), where the old form gave 0/0 = nan for odd numx.
    # NOTE(review): this angle convention and the [i, j] = (x, y) layout are
    # unchanged from the original; the plotting code relies on both together.
    rhoThetaVals = np.column_stack((np.sqrt(xFlat**2 + yFlat**2), np.arctan2(xFlat, yFlat)))

    # Cache of evaluated orbitals keyed by spatial index (spin dropped).
    orbitals = {}

    def orbital(sp):
        k = sp % nSp
        if k not in orbitals:
            n, l = FDBasis[k]
            orbitals[k] = fockDarwin(n, l, rhoThetaVals)
        return orbitals[k]

    for i,fi in enumerate(fockBasis):
        if abs(eState[i])>cutoff:
            #diagonal in fock basis
            for sp in fi:
                nofr+=abs(eState[i]*orbital(sp))**2
    for i in oneParticleDiffDict:
        if abs(eState[i])>cutoff: #save computation time by neglecting negligible terms
            fi=fockBasis[i]
            #off-diagonal in fock basis
            for j in oneParticleDiffDict[i]:
                if abs(eState[j])>cutoff:
                    fj=fockBasis[j]
                    spa=set(fi).difference(set(fj)).pop()
                    spb=set(fj).difference(set(fi)).pop()
                    if (spa>=nSp and spb>=nSp) or (spa<nSp and spb<nSp): #make sure same spin
                        if spa < spb:
                            betweenCount=np.count_nonzero((fj>spa)&(fj<spb))
                        elif spb < spa:
                            betweenCount=np.count_nonzero((fj>spb)&(fj<spa))
                        sign=(-1)**(betweenCount)
                        nofr+=np.conj(eState[j]*orbital(spa))*eState[i]*orbital(spb)*sign
    nofr=np.real(nofr.reshape(numx,numx))
    return(nofr)

plt.rcParams["figure.figsize"] = (2.00,2.00)

xMax = 6 # nm
numx = 100

# The FD orbitals are dimensionless in units of the oscillator length l, so
# the grid half-width is passed as xMax/l.
nofr = computenofr(gsVecs[0],oneParticleDiffDictFD,fockBasisFD,FDBasis,xMax/l,numx)

plt.rcParams["figure.figsize"] = (2.00,2.00)

ax=plt.gca()
colormap='magma'

# np.linspace(-a, a, numx) has spacing 2a/(numx-1) (not 2a/numx); use it as
# the area element of the Riemann sum, which should approximate NU + ND.
dxl = 2*(xMax/l)/(numx-1)
ntot = np.sum(nofr) * dxl**2

print("integrated charge density: np.sum(nofr):", np.sum(ntot))

# nofr is a density per l^2; multiplying by A/l^2 gives (presumably) electrons
# per moire unit-cell area.
ATimesnofr = nofr * (A/l**2)

ATimesnofrSaveName=("ATimesnofr"+paramString+"xMax%.2fnumx%dunitsnm"%(xMax,numx))
np.save(dataDir+ATimesnofrSaveName, ATimesnofr)

imshow = ax.imshow(ATimesnofr, cmap = colormap, origin = 'lower', extent = [-xMax,xMax,-xMax,xMax], interpolation="gaussian")
#plt.contour(np.abs(nofr), origin='lower', color='blue')
cMin = 0
cMax = np.amax(ATimesnofr)
imshow.set_clim(cMin, cMax)
plt.ylabel(r'$y$ $(nm)$')
plt.xlabel(r'$x$ $(nm)$')
#plt.xticks([-4,-2,0,2,4])
#plt.yticks([-4,-2,0,2,4])
#plt.title(r'$l^2n(\mathbf{r})$')
cbar = plt.colorbar(imshow, fraction=0.046, pad=0.04)
cbar.ax.set_ylabel(r'$A_{u.c.}n(\mathbf{r})$')
plt.savefig(figDir+'WigMolCharge'+paramString+'.pdf',bbox_inches='tight')
plt.show()

print("completed")

######